{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Slithering Snake Example\n",
    "\n",
    "This Elastica tutorial explains how to set up a Cosserat rod simulation of a slithering snake. It is a more complex use case than the Timoshenko beam example. If you have not done so, we strongly suggest you start with [this beam example](./1_Timoshenko_Beam.ipynb), as it covers many of the basics of setting up and running simulations with Elastica. \n",
    "\n",
    "This slithering snake example includes gravitational forces, friction forces, and internal muscle torques. It also introduces callback functions, which log simulation data for post-processing after the simulation is over. \n",
    "\n",
    "\n",
    "## Getting Started\n",
    "To set up the simulation, the first thing you need to do is import the necessary classes. As with the Timoshenko beam, we need to import wrapper classes which allow us to more easily construct different simulation systems. We also need to import a rod class, all the necessary forces to be applied, timestepping functions, and callback classes. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# import wrappers\n",
    "from elastica.wrappers import BaseSystemCollection, Constraints, Forcing, CallBacks\n",
    "\n",
    "# import rod class and forces to be applied\n",
    "from elastica.rod.cosserat_rod import CosseratRod\n",
    "from elastica.external_forces import GravityForces, MuscleTorques\n",
    "from elastica.interaction import AnisotropicFrictionalPlane\n",
    "\n",
    "# import timestepping functions\n",
    "from elastica.timestepper.symplectic_steppers import PositionVerlet\n",
    "from elastica.timestepper import integrate\n",
    "\n",
    "# import call back functions\n",
    "from elastica.callback_functions import CallBackBaseClass\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize System and Add Rod\n",
    "The first thing to do is initialize the simulator class by combining all the imported wrappers. After initializing, we will generate a rod and add it to the simulation. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks):\n",
    "    pass\n",
    "snake_sim = SnakeSimulator()\n",
    "\n",
    "# Define rod parameters\n",
    "n_elem = 50\n",
    "start = np.array([0.0, 0.0, 0.0])\n",
    "direction = np.array([0.0, 0.0, 1.0])\n",
    "normal = np.array([0.0, 1.0, 0.0])\n",
    "base_length = 1.0\n",
    "base_radius = 0.025\n",
    "base_area = np.pi * base_radius ** 2\n",
    "density = 1000\n",
    "nu = 5.0\n",
    "E = 1e7\n",
    "poisson_ratio = 0.5\n",
    "\n",
    "# Create rod\n",
    "shearable_rod = CosseratRod.straight_rod(\n",
    "    n_elem,\n",
    "    start,\n",
    "    direction,\n",
    "    normal,\n",
    "    base_length,\n",
    "    base_radius,\n",
    "    density,\n",
    "    nu,\n",
    "    E,\n",
    "    poisson_ratio,\n",
    ")\n",
    "\n",
    "# Add rod to the snake system\n",
    "snake_sim.append(shearable_rod)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Forces to Rod\n",
    "With our rod added to the system, we need to specify the relevant forces that will act on it. Every force is added the same way: `system_name.add_forcing_to(name_of_rod).using(type_of_force, **kwargs)`, where `**kwargs` are the keyword parameters specific to each type of force. \n",
    "\n",
    "### Gravity\n",
    "The first force to add is gravity. We specify the strength of gravity and also the direction it is pointing. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add gravitational forces\n",
    "gravitational_acc = -9.80665\n",
    "snake_sim.add_forcing_to(shearable_rod).using(\n",
    "    GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])\n",
    ")\n",
    "print('Gravity now acting on shearable rod')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Muscle Torques\n",
    "A snake generates torque throughout its body through muscle activations. While these muscle activations are generated internally by the snake, it is simpler to treat them as applied external forces, allowing us to apply them to the rod in the same manner as the other external forces. \n",
    "\n",
    "You may notice that the muscle torque parameters appear to have special values. These are optimized coefficients for a snake gait. For information about how to do this optimization, see the [snake optimization example script](../ContinuumSnakeCase/continuum_snake.py)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Define muscle torque parameters\n",
    "period = 1.0\n",
    "wave_length = 0.97\n",
    "b_coeff=np.array([17.4, 48.5, 5.4, 14.7])\n",
    "\n",
    "# Add muscle torques to the rod\n",
    "snake_sim.add_forcing_to(shearable_rod).using(\n",
    "    MuscleTorques,\n",
    "    base_length=base_length,\n",
    "    b_coeff=b_coeff,\n",
    "    period=period,\n",
    "    wave_number=2.0 * np.pi / (wave_length),\n",
    "    phase_shift=0.0,\n",
    "    rest_lengths=shearable_rod.rest_lengths,\n",
    "    ramp_up_time=period,\n",
    "    direction=normal,\n",
    "    with_spline=True,\n",
    ")\n",
    "print('Muscle torques added to the rod')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Anisotropic Friction Forces\n",
    "The last force that needs to be added is the friction force between the snake and the ground. Snakes exhibit anisotropic friction, where the friction coefficient differs between directions (here forward, backward, and sideways). You can also define both static and kinetic friction coefficients. This is accomplished by defining a small velocity threshold `slip_velocity_tol` that sets the transition between static and kinetic friction. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define friction force parameters\n",
    "origin_plane = np.array([0.0, -base_radius, 0.0])\n",
    "normal_plane = normal\n",
    "slip_velocity_tol = 1e-8\n",
    "froude = 0.1\n",
    "mu = base_length / (period * period * np.abs(gravitational_acc) * froude)\n",
    "kinetic_mu_array = np.array([1.0 * mu, 1.5 * mu, 2.0 * mu])  # [forward, backward, sideways]\n",
    "static_mu_array = 2 * kinetic_mu_array\n",
    "\n",
    "# Add friction forces to the rod\n",
    "snake_sim.add_forcing_to(shearable_rod).using(\n",
    "    AnisotropicFrictionalPlane,\n",
    "    k=1.0,\n",
    "    nu=1e-6,\n",
    "    plane_origin=origin_plane,\n",
    "    plane_normal=normal_plane,\n",
    "    slip_velocity_tol=slip_velocity_tol,\n",
    "    static_mu_array=static_mu_array,\n",
    "    kinetic_mu_array=kinetic_mu_array,\n",
    ")\n",
    "print('Friction forces added to the rod')"
   ]
  },
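  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the friction parameters above, the sketch below evaluates the same expression for `mu` on its own and then rearranges it. It reuses the numbers from this notebook, and the `*_check` names are introduced here purely for illustration. Rearranged, the formula reads `froude = (base_length / period**2) / (mu * |g|)`, so `froude` is the ratio of a characteristic gait acceleration to the frictional deceleration, and a smaller `froude` means a larger friction coefficient. This reading follows from the algebra alone and is not taken from the library documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Values copied from the cells above; the *_check names are illustrative only\n",
    "base_length_check = 1.0\n",
    "period_check = 1.0\n",
    "g_check = 9.80665\n",
    "froude_check = 0.1\n",
    "\n",
    "# Same expression as `mu` in the friction cell\n",
    "mu_check = base_length_check / (period_check ** 2 * g_check * froude_check)\n",
    "kinetic_check = mu_check * np.array([1.0, 1.5, 2.0])  # [forward, backward, sideways]\n",
    "static_check = 2.0 * kinetic_check\n",
    "\n",
    "# Rearranged: characteristic acceleration L / T^2 over frictional deceleration mu * g\n",
    "froude_recovered = (base_length_check / period_check ** 2) / (mu_check * g_check)\n",
    "\n",
    "print(\"mu:\", mu_check)\n",
    "print(\"kinetic coefficients:\", kinetic_check)\n",
    "print(\"static coefficients:\", static_check)\n",
    "print(\"recovered froude:\", froude_recovered)"
   ]
  },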
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Callback Function\n",
    "The simulation is now set up, but before it is run, we want to define a callback function. A callback function allows us to record time-series data throughout the simulation. If you do not define a callback function, you will only have access to the final configuration of the system. If you want to analyze how the system evolves over time, it is critical that you record the appropriate quantities. \n",
    "\n",
    "To create a callback function, begin with the `CallBackBaseClass`. You can then define which state quantities you wish to record by having them appended to the `self.callback_params` dictionary, as well as how often you wish to save the data by defining `step_skip`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add call backs\n",
    "class ContinuumSnakeCallBack(CallBackBaseClass):\n",
    "    \"\"\"\n",
    "    Call back function for continuum snake\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, step_skip: int, callback_params: dict):\n",
    "        CallBackBaseClass.__init__(self)\n",
    "        self.every = step_skip\n",
    "        self.callback_params = callback_params\n",
    "\n",
    "    def make_callback(self, system, time, current_step: int):\n",
    "\n",
    "        if current_step % self.every == 0:\n",
    "\n",
    "            self.callback_params[\"time\"].append(time)\n",
    "            self.callback_params[\"step\"].append(current_step)\n",
    "            self.callback_params[\"position\"].append(\n",
    "                system.position_collection.copy()\n",
    "            )\n",
    "            self.callback_params[\"velocity\"].append(\n",
    "                system.velocity_collection.copy()\n",
    "            )\n",
    "            self.callback_params[\"avg_velocity\"].append(\n",
    "                system.compute_velocity_center_of_mass()\n",
    "            )\n",
    "\n",
    "            self.callback_params[\"center_of_mass\"].append(\n",
    "                system.compute_position_center_of_mass()\n",
    "            )\n",
    "\n",
    "            return\n",
    "\n",
    "pp_list = defaultdict(list)\n",
    "snake_sim.collect_diagnostics(shearable_rod).using(\n",
    "    ContinuumSnakeCallBack, step_skip=200, callback_params=pp_list\n",
    ")\n",
    "print('Callback function added to the simulator')"
   ]
  },
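  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The callback above relies on `pp_list` being a `defaultdict(list)`: a key gets an empty list the first time something is appended to it, so the callback never has to initialise `\"time\"`, `\"position\"`, and so on. The toy loop below mimics that recording pattern without Elastica. The names `history`, `record`, and `dt_demo` are hypothetical stand-ins for `pp_list`, `make_callback`, and the real time step, not part of the library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "history = defaultdict(list)  # plays the role of pp_list\n",
    "every = 200  # plays the role of step_skip\n",
    "\n",
    "\n",
    "def record(time, current_step):\n",
    "    # Hypothetical stand-in for make_callback: keep only every `every`-th step\n",
    "    if current_step % every == 0:\n",
    "        history[\"time\"].append(time)\n",
    "        history[\"step\"].append(current_step)\n",
    "\n",
    "\n",
    "dt_demo = 4.0e-5\n",
    "for step in range(1001):\n",
    "    record(step * dt_demo, step)\n",
    "\n",
    "print(\"recorded steps:\", history[\"step\"])\n",
    "print(\"recorded keys:\", sorted(history.keys()))"
   ]
  },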
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the callback function added, we can now finalize the system and also define the time stepping parameters of the simulation such as the time step, final time, and time stepping algorithm to use. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "snake_sim.finalize()\n",
    "\n",
    "final_time = 5.0 * period\n",
    "dt = 4.0e-5\n",
    "total_steps = int(final_time / dt)\n",
    "print(\"Total steps\", total_steps)\n",
    "\n",
    "timestepper = PositionVerlet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now all that is left is to run the simulation. Using the default parameters the simulation takes about 2-3 minutes to complete. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "integrate(timestepper, snake_sim, final_time, total_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Post-Process Data\n",
    "With the simulation complete, we want to analyze the simulation. Because we added a callback function, we can analyze how the snake evolves over time. All of the data from the callback function is located in the `pp_list` dictionary. Here we will use this information to compute and plot the velocity of the snake in the forward, lateral, and normal directions. We do this by using a pre-written analysis function `compute_projected_velocity`.\n",
    "\n",
    "In the plotted graph, you can see that it takes about one period for the snake to begin moving before rapidly reaching a steady gait over just 2-3 periods. We also see that the normal velocity is zero since we are only actuating the snake in a 2D plane. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def compute_and_plot_velocity(plot_params: dict, period):\n",
    "    from matplotlib import pyplot as plt\n",
    "    from analysis_functions import compute_projected_velocity\n",
    "\n",
    "    time_per_period = np.array(plot_params[\"time\"]) / period\n",
    "    avg_velocity = np.array(plot_params[\"avg_velocity\"])\n",
    "\n",
    "    [velocity_in_direction_of_rod, \n",
    "     velocity_in_rod_roll_dir, \n",
    "     avg_forward, \n",
    "     avg_lateral] = compute_projected_velocity(plot_params, period)\n",
    "    \n",
    "    print(\"average forward velocity:\", avg_forward)\n",
    "    print(\"average lateral velocity:\", avg_lateral)\n",
    "\n",
    "    fig = plt.figure(figsize=(5,4), frameon=True, dpi=150)\n",
    "    ax = fig.add_subplot(111)\n",
    "    ax.grid(True, which=\"major\", color=\"grey\", linestyle=\"-\", linewidth=0.25)\n",
    "    ax.plot(time_per_period[:], velocity_in_direction_of_rod[:, 2], \"r-\", label=\"forward\")\n",
    "    ax.plot(time_per_period[:], velocity_in_rod_roll_dir[:, 0], \"b-\", label=\"lateral\",)\n",
    "    ax.plot(time_per_period[:], avg_velocity[:, 1], \"g-\", label=\"normal\")\n",
    "    ax.legend(prop={\"size\": 12})\n",
    "    ax.set_ylabel('Velocity (m/s)', fontsize = 12)\n",
    "    ax.set_xlabel('Time (periods)', fontsize = 12)\n",
    "    plt.show() \n",
    "\n",
    "compute_and_plot_velocity(pp_list, period)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make Video of Snake Gait\n",
    "Because we saved data of the snake's behavior, we can make a video of its movement. The easiest way to do this is to plot the snake's position at each time that the data was recorded and then stitch these plots together to form a video. \n",
    "\n",
    "Note: ffmpeg is required for matplotlib to be able to create a video. More info on ffmpeg [here](https://www.ffmpeg.org/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Video\n",
    "\n",
    "def plot_video( plot_params: dict, video_name=\"video.mp4\", margin=0.2, fps=15):  \n",
    "    from matplotlib import pyplot as plt\n",
    "    import matplotlib.animation as manimation\n",
    "\n",
    "    positions_over_time = np.array(plot_params[\"position\"])\n",
    "\n",
    "    print(\"creating video -- this can take a few minutes\")\n",
    "    FFMpegWriter = manimation.writers[\"ffmpeg\"]\n",
    "    metadata = dict(title=\"Movie Test\", artist=\"Matplotlib\", comment=\"Movie support!\")\n",
    "    writer = FFMpegWriter(fps=fps, metadata=metadata)\n",
    "    fig = plt.figure()\n",
    "    plt.axis(\"equal\")\n",
    "    with writer.saving(fig, video_name, dpi=100):\n",
    "        for time in range(1, len(plot_params[\"time\"])):\n",
    "            x = positions_over_time[time][2]\n",
    "            y = positions_over_time[time][0]\n",
    "            fig.clf()\n",
    "            plt.plot(x, y, \"-\", linewidth=3)\n",
    "            plt.xlim([0 - margin, 3 + margin])\n",
    "            plt.ylim([-1.5 - margin, 1.5 + margin])\n",
    "            writer.grab_frame()\n",
    "    plt.close(fig)\n",
    "\n",
    "filename_video = \"continuum_snake.mp4\"\n",
    "plot_video(pp_list, video_name=filename_video, margin=0.2, fps=125)\n",
    "    \n",
    "Video(\"continuum_snake.mp4\")"
   ]
  },
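  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because each video frame corresponds to one recorded callback sample, the frame rate determines how fast the snake appears to move. The sketch below works through the arithmetic with the values used in this notebook (`dt`, `step_skip`, and the `fps` passed to `plot_video`); the `*_check` names are illustrative only. With these values the spacing between saved samples matches the frame duration, so the video should play back at roughly real time. If you change `dt` or `step_skip`, adjust `fps` accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values copied from earlier cells; the *_check names are illustrative only\n",
    "dt_check = 4.0e-5\n",
    "step_skip_check = 200\n",
    "fps_check = 125\n",
    "\n",
    "# Simulated time that elapses between two saved frames\n",
    "seconds_per_frame = step_skip_check * dt_check\n",
    "# Frame rate at which one second of video shows one second of simulation\n",
    "realtime_fps = 1.0 / seconds_per_frame\n",
    "# Playback speed relative to real time (1.0 means real time)\n",
    "playback_speed = fps_check * seconds_per_frame\n",
    "\n",
    "print(\"simulated seconds per saved frame:\", seconds_per_frame)\n",
    "print(\"fps for real-time playback:\", realtime_fps)\n",
    "print(\"playback speed factor:\", playback_speed)"
   ]
  },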
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, you can also plot the position of the snake from a 3D perspective. This is most helpful if you have a simulation that consists of more than planar motion. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Video\n",
    "\n",
    "def plot_video( plot_params: dict, video_name=\"video.mp4\", margin=0.2, fps=15):  \n",
    "    from matplotlib import pyplot as plt\n",
    "    import matplotlib.animation as manimation\n",
    "    from mpl_toolkits import mplot3d\n",
    "    \n",
    "    positions_over_time = np.array(plot_params[\"position\"])\n",
    "    print(\"creating video -- this can take a few minutes\")\n",
    "    FFMpegWriter = manimation.writers[\"ffmpeg\"]\n",
    "    metadata = dict(title=\"Movie Test\", artist=\"Matplotlib\", comment=\"Movie support!\")\n",
    "    writer = FFMpegWriter(fps=fps, metadata=metadata)\n",
    "    fig = plt.figure()\n",
    "    with writer.saving(fig, video_name, dpi=100):\n",
    "        for time in range(1, len(plot_params[\"time\"])): \n",
    "            x = positions_over_time[time][0]\n",
    "            y = positions_over_time[time][2]\n",
    "            z = positions_over_time[time][1]\n",
    "            \n",
    "            fig.clf()\n",
    "            ax = fig.add_subplot( 111, projection=\"3d\" )\n",
    "            ax.plot(y, x, z, \"-\", linewidth=3)\n",
    "            ax.set_xlim(0 - margin, 3 + margin)\n",
    "            ax.set_ylim(-1.5 - margin, 1.5 + margin)\n",
    "            ax.set_zlim(0, 1)\n",
    "        \n",
    "            ax.view_init(elev=20, azim=-80)\n",
    "            writer.grab_frame()\n",
    "    plt.close(fig)\n",
    "filename_video = \"continuum_snake_3d.mp4\"\n",
    "plot_video(pp_list, video_name=filename_video, margin=0.2, fps=125)\n",
    "    \n",
    "Video(\"continuum_snake_3d.mp4\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
